import imageio
import logging
import os
from timeit import default_timer
from collections import defaultdict

from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import trange
import torch
from torch.nn import functional as F

from disvae.models.losses import _reconstruction_loss
from disvae.utils.modelIO import save_model
import wandb

# File name, joined onto ``save_dir``, that Trainer/Refiner hand to LossesLogger
# for the per-epoch "Epoch,Loss,Value" loss log.
TRAIN_LOSSES_LOGFILE = "train_losses.log"


class Trainer():
    """
    Class to handle training of model.

    Parameters
    ----------
    model: disvae.vae.VAE
        Called as ``model(imgs)`` and expected to return
        ``(recon_batch, latent_dist, latent_sample)``.

    optimizer: torch.optim.Optimizer

    loss_f: disvae.models.BaseLoss
        Loss function. Losses that use multiple optimizers (e.g. Factor) signal
        this with a ``ValueError`` from ``__call__`` and are then driven via
        ``loss_f.call_optimize``.

    device: torch.device, optional
        Device on which to run the code.

    logger: logging.Logger, optional
        Logger.

    save_dir : str, optional
        Directory for saving logs and checkpoints.

    gif_visualizer : viz.Visualizer, optional
        Gif Visualizer that should return samples at every epochs.

    is_progress_bar: bool, optional
        Whether to use a progress bar for training.

    checkpoint_callback: callable, optional
        Called with no arguments right after each checkpoint is saved.
    """

    def __init__(self, model, optimizer, loss_f,
                 device=torch.device("cpu"),
                 logger=logging.getLogger(__name__),
                 save_dir="results",
                 gif_visualizer=None,
                 is_progress_bar=True,
                 checkpoint_callback=None):

        self.device = device
        self.model = model.to(self.device)

        self.loss_f = loss_f
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger = logger
        self.losses_logger = LossesLogger(os.path.join(self.save_dir, TRAIN_LOSSES_LOGFILE))
        self.gif_visualizer = gif_visualizer
        self.logger.info("Training Device: {}".format(self.device))
        self.checkpoint_callback = checkpoint_callback
        # Set to True by __call__ when early_stop() ends training before `epochs`.
        self.stopped = False

    def early_stop(self, storer):
        """Hook evaluated after every epoch; return True to stop training.

        The base implementation never stops. Subclasses may override it and
        inspect ``storer``, the epoch's dict of collected loss values.
        """
        return False

    def __call__(self, data_loader,
                 epochs=10,
                 checkpoint_every=10):
        """
        Trains the model.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader
            Yields ``(imgs, labels)`` batches; labels are not used.

        epochs: int, optional
            Number of epochs to train the model for.

        checkpoint_every: int, optional
            Save a checkpoint of the trained model every n epoch.
        """
        start = default_timer()
        self.model.train()
        self.stopped = False
        # Visualize at most ~20 times over the run, plus always on the last epoch.
        gif_every = max(epochs // 20, 1)

        for epoch in range(epochs):
            storer = defaultdict(list)
            mean_epoch_loss = self._train_epoch(data_loader, storer, epoch)
            self.logger.info('Epoch: {} Average loss per image: {:.2f}'.format(epoch + 1,
                                                                               mean_epoch_loss))
            self.losses_logger.log(epoch, storer)

            if self.gif_visualizer is not None and (epoch % gif_every == 0 or epoch == epochs - 1):
                self.gif_visualizer()

            if (epoch + 1) % checkpoint_every == 0:
                # NOTE: the file name uses the 0-based epoch index.
                save_model(self.model, self.save_dir,
                           filename="model-{}.pt".format(epoch))
                if self.checkpoint_callback is not None:
                    self.checkpoint_callback()

            if self.early_stop(storer):
                self.stopped = True
                break

        self.model.eval()

        delta_time = (default_timer() - start) / 60
        self.logger.info('Finished training after {:.1f} min.'.format(delta_time))

    def _train_epoch(self, data_loader, storer, epoch):
        """
        Trains the model for one epoch.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        storer: dict
            Dictionary in which to store important variables for vizualisation.

        epoch: int
            Epoch number

        Return
        ------
        mean_epoch_loss: float
            Mean loss per batch of the epoch.
        """
        epoch_loss = 0.
        n_batches = len(data_loader)
        kwargs = dict(desc="Epoch {}".format(epoch + 1), leave=False,
                      disable=not self.is_progress_bar)
        # The bar is advanced manually; iterate the loader itself, not the bar,
        # otherwise the loop would receive integers instead of batches.
        with trange(n_batches, **kwargs) as t:
            for itr, data in enumerate(data_loader):
                iter_loss = self._train_iteration(data, storer)
                epoch_loss += iter_loss

                t.set_postfix(loss=iter_loss)
                t.update()

                # Global step across epochs; every 10 steps push the mean of each
                # storer list collected so far this epoch (empty lists skipped,
                # their mean is undefined).
                step = itr + epoch * n_batches
                if (step + 1) % 10 == 0:
                    ndict = {k: mean(v) for k, v in storer.items()
                             if isinstance(v, list) and v}
                    ndict['itr'] = step
                    wandb.log(ndict, sync=False)

        return epoch_loss / n_batches

    def _train_iteration(self, data, storer):
        """
        Trains the model for one iteration on a batch of data.

        Parameters
        ----------
        data: (imgs:torch.Tensor, label:torch.Tensor)
            A batch of data. Shape : (batch_size, channel, height, width).
            Labels are ignored.

        storer: dict
            Dictionary in which to store important variables for vizualisation.

        Return
        ------
        loss: float
            Loss value of this batch.
        """
        imgs, _ = data
        imgs = imgs.to(self.device).float()

        recon_batch, latent_dist, latent_sample = self.model(imgs)
        # Only the loss call may legitimately raise ValueError: losses that use
        # multiple optimizers (e.g. Factor) do so to request call_optimize.
        # Errors from the model itself must propagate, not be rerouted.
        try:
            loss = self.loss_f(imgs, recon_batch, latent_dist, self.model.training,
                               storer, latent_sample=latent_sample)
        except ValueError:
            loss = self.loss_f.call_optimize(imgs, self.model, self.optimizer, storer)
            return loss.item()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()


class Refiner():
    """
    Class to refine the decoder of a model on fixed latent codes.

    At construction every batch of ``data_loader`` is encoded once with
    ``model.encode``. Training then runs only ``model.decoder`` on the stored
    codes and minimises ``_reconstruction_loss`` against the original images.

    Parameters
    ----------
    model: disvae.vae.VAE
        Must expose ``encode`` (returning a tensor per batch) and ``decoder``.

    optimizer: torch.optim.Optimizer

    loss_f: disvae.models.BaseLoss
        Stored for interface parity with ``Trainer``; refinement itself uses
        ``_reconstruction_loss``.

    data_loader: torch.utils.data.DataLoader
        Yields ``(imgs, labels)`` batches; labels are ignored.

    device: torch.device, optional
        Device on which to run the code.

    logger: logging.Logger, optional
        Logger.

    save_dir : str, optional
        Directory for saving logs.

    gif_visualizer : viz.Visualizer, optional
        Gif Visualizer that should return samples at every epochs.

    is_progress_bar: bool, optional
        Whether to use a progress bar for training.

    refine_batch_size: int, optional
        Batch size of the shuffled loader built over the encoded dataset.
    """

    def __init__(self, model, optimizer, loss_f, data_loader,
                 device=torch.device("cpu"),
                 logger=logging.getLogger(__name__),
                 save_dir="results",
                 gif_visualizer=None,
                 is_progress_bar=True,
                 refine_batch_size=128):

        self.device = device
        self.model = model.to(self.device)
        self.data_loader = data_loader

        self.loss_f = loss_f
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger = logger
        self.losses_logger = LossesLogger(os.path.join(self.save_dir, TRAIN_LOSSES_LOGFILE))
        self.gif_visualizer = gif_visualizer
        self.logger.info("Training Device: {}".format(self.device))

        pairs = self._encode_dataset(data_loader)
        self.refine_dl = DataLoader(pairs, batch_size=refine_batch_size, shuffle=True)

    def _encode_dataset(self, data_loader):
        """Encode every sample once and return a list of ``(z, x)`` pairs on the CPU.

        Batches are moved to ``self.device`` before encoding (the model lives
        there). Codes are brought back to the CPU so the whole dataset is not
        held in device memory; ``_train_iteration`` moves each batch back.

        Raises
        ------
        ValueError
            If ``data_loader`` yields no batches.
        """
        codes, images = [], []
        with torch.no_grad():
            for imgs, _ in data_loader:
                codes.append(self.model.encode(imgs.to(self.device).float()).cpu())
                images.append(imgs)
        if not codes:
            raise ValueError("data_loader yielded no batches; nothing to refine on.")
        return list(zip(torch.cat(codes), torch.cat(images)))

    def __call__(self, epochs=10):
        """Refine the decoder for ``epochs`` epochs over the pre-encoded dataset."""
        self.model.train()
        start = default_timer()
        for epoch in range(epochs):
            storer = defaultdict(list)
            mean_epoch_loss = self._train_epoch(self.refine_dl, storer, epoch)
            self.logger.info('Epoch: {} Average loss per image: {:.2f}'.format(epoch + 1,
                                                                               mean_epoch_loss))
            self.losses_logger.log(epoch, storer)

            if self.gif_visualizer is not None:
                self.gif_visualizer()

        self.model.eval()

        delta_time = (default_timer() - start) / 60
        self.logger.info('Finished training after {:.1f} min.'.format(delta_time))

    def _train_epoch(self, data_loader, storer, epoch):
        """
        Trains the model for one epoch.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader
            Yields ``(zs, imgs)`` batches.

        storer: dict
            Dictionary in which to store important variables for vizualisation.

        epoch: int
            Epoch number

        Return
        ------
        mean_epoch_loss: float
            Mean loss per batch of the epoch.
        """
        epoch_loss = 0.
        n_batches = len(data_loader)
        kwargs = dict(desc="Epoch {}".format(epoch + 1), leave=False,
                      disable=not self.is_progress_bar)
        # Iterate the loader and advance the bar by hand. Iterating the trange
        # itself would yield integers instead of (zs, imgs) batches.
        with trange(n_batches, **kwargs) as t:
            for itr, data in enumerate(data_loader):
                iter_loss = self._train_iteration(data, storer)
                epoch_loss += iter_loss

                t.set_postfix(loss=iter_loss)
                t.update()

                # Global step across epochs; every 100 steps push the mean of
                # each non-empty storer list collected so far this epoch.
                step = itr + epoch * n_batches
                if (step + 1) % 100 == 0:
                    ndict = {k: mean(v) for k, v in storer.items()
                             if isinstance(v, list) and v}
                    ndict['itr'] = step
                    wandb.log(ndict, sync=False)

        return epoch_loss / n_batches

    def _train_iteration(self, data, storer):
        """Decode one batch of stored codes, take an optimizer step, return the loss."""
        zs, imgs = data

        imgs = imgs.to(self.device).float()
        zs = zs.to(self.device)

        recon_batch = self.model.decoder(zs)
        loss = _reconstruction_loss(imgs, recon_batch, storer=storer)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()


class LossesLogger(object):
    """Class definition for objects to write data to log files in a
    form which is then easy to be plotted.

    The file starts with the header ``Epoch,Loss,Value`` followed by one
    ``epoch,loss_name,mean_value`` line per logged loss.
    """

    def __init__(self, file_path_name):
        """Create a logger that writes ``file_path_name`` from scratch.

        All instances share the ``"losses_logger"`` logging.Logger. Any
        FileHandler left on it by a previous instance is closed and detached
        first. Otherwise every later record would also be written into the
        older files, and the still-open handle would block ``os.remove`` on
        Windows. An earlier instance that keeps logging will therefore write
        into the newest file.
        """
        self.logger = logging.getLogger("losses_logger")
        for handler in list(self.logger.handlers):
            if isinstance(handler, logging.FileHandler):
                self.logger.removeHandler(handler)
                handler.close()

        if os.path.isfile(file_path_name):
            os.remove(file_path_name)

        self.logger.setLevel(1)  # always store
        file_handler = logging.FileHandler(file_path_name)
        file_handler.setLevel(1)
        self.logger.addHandler(file_handler)

        header = ",".join(["Epoch", "Loss", "Value"])
        self.logger.debug(header)

    def log(self, epoch, losses_storer):
        """Write one ``epoch,name,mean`` line per entry of ``losses_storer``.

        Entries with no recorded values are skipped, since their mean is undefined.
        """
        for k, v in losses_storer.items():
            if len(v) == 0:
                continue
            log_string = ",".join(str(item) for item in [epoch, k, mean(v)])
            self.logger.debug(log_string)


# HELPERS
def mean(l):
    """Return the arithmetic mean of the values in ``l``.

    ``l`` must support ``len``; an empty ``l`` raises ZeroDivisionError.
    """
    total = sum(l)
    count = len(l)
    return total / count
